"""

Perplexity Metric:
-------------------------------------------------------
Class for calculating perplexity from Jailbreak_Dataset

"""

import torch
from easyjailbreak.metrics.Metric.metric import Metric
from easyjailbreak.datasets import JailbreakDataset
from easyjailbreak.models import WhiteBoxModelBase


class Perplexity(Metric):
    """Metric that computes the average perplexity of the jailbreak prompts
    in a :class:`JailbreakDataset`, using a white-box language model
    (e.g. a small GPT-2) with the strided sliding-window method from
    huggingface.co/transformers/perplexity.html.
    """

    def __init__(self, model: WhiteBoxModelBase, max_length: int = 512, stride: int = 512):
        """
        Initialize the evaluator with a white-box model wrapper.

        :param model: ``WhiteBoxModelBase`` bundling both the language model
            (``model.model``) and its tokenizer (``model.tokenizer``).
        :param max_length: Maximum context window, in tokens, fed to the model
            for each perplexity window. Default 512.
        :param stride: Step size, in tokens, between consecutive windows.
            Default 512 (i.e. non-overlapping windows).

        Example::

            from transformers import GPT2LMHeadModel, GPT2Tokenizer
            from easyjailbreak.models import WhiteBoxModelBase
            model = WhiteBoxModelBase(GPT2LMHeadModel.from_pretrained("gpt2"),
                                      GPT2Tokenizer.from_pretrained("gpt2"))
            evaluator = Perplexity(model)
        """
        self.all_metrics = {}   # metric name -> computed value
        self.prompts = []       # prompts whose perplexity is computed

        # Unpack model and tokenizer from the white-box wrapper.
        self.ppl_model = model.model
        self.ppl_tokenizer = model.tokenizer

        # Inference only: disable dropout / other training-mode behavior.
        self.ppl_model.eval()

        self.max_length = max_length
        self.stride = stride

    def calculate(self, dataset: JailbreakDataset):
        """Calculate the average perplexity over the ``jailbreak_prompt`` of
        every instance in ``dataset``.

        :param dataset: ``JailbreakDataset`` of instances with attack results.
        :return: ``self.all_metrics`` dict, updated with the key
            ``"avg_prompt_perplexity"`` (perplexity rounded to 2 decimals).
        :raises ValueError: if the dataset yields no tokens to score.
        """
        self.dataset = dataset

        # Rebuild the prompt list on every call. The original appended to
        # self.prompts without resetting it, so repeated calls accumulated
        # prompts from earlier datasets and skewed the metric.
        self.prompts = [instance.jailbreak_prompt for instance in dataset]

        ppl = self.calc_ppl(self.prompts)
        self.all_metrics["avg_prompt_perplexity"] = round(ppl, 2)
        return self.all_metrics

    def calc_ppl(self, texts):
        """Compute the perplexity of the space-joined concatenation of
        ``texts`` using a strided sliding window.

        Strided perplexity calculation adapted from
        huggingface.co/transformers/perplexity.html.

        :param texts: list of strings; joined with single spaces into one
            token sequence before scoring.
        :return: perplexity as a Python float.
        :raises ValueError: if tokenization produces no tokens (the original
            raised an opaque ``NameError`` on ``end_loc`` in this case).
        """
        text = " ".join(texts)
        token_ids = self.ppl_tokenizer.encode(text, add_special_tokens=True)
        if not token_ids:
            # Guard: with zero tokens the loop below never runs and end_loc
            # would be unbound; fail loudly with a clear message instead.
            raise ValueError("calc_ppl received no tokens to score "
                             "(empty dataset or empty prompts).")

        input_ids = torch.tensor(token_ids).unsqueeze(0)
        seq_len = input_ids.size(1)
        eval_loss = []

        with torch.no_grad():
            for i in range(0, seq_len, self.stride):
                begin_loc = max(i + self.stride - self.max_length, 0)
                end_loc = min(i + self.stride, seq_len)
                trg_len = end_loc - i  # number of tokens scored in this window
                window_ids = input_ids[:, begin_loc:end_loc].to(
                    self.ppl_model.device
                )
                target_ids = window_ids.clone()
                # Mask leading context so only the last trg_len positions
                # contribute to the loss (-100 is ignored by the CE loss).
                target_ids[:, :-trg_len] = -100

                outputs = self.ppl_model(window_ids, labels=target_ids)
                # outputs[0] is the mean NLL over scored tokens; rescale to a
                # total NLL so windows of different sizes weight correctly.
                eval_loss.append(outputs[0] * trg_len)

        # end_loc == seq_len after the final iteration: total NLL / token count.
        return torch.exp(torch.stack(eval_loss).sum() / end_loc).item()
